package com.downguys.tool.netty;

import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Factory built on top of Netty's {@link SslContextBuilder} whose purpose is to reuse
 * {@link SslContext} instances instead of rebuilding them for every connection.
 *
 * <p>The former no-argument {@code getclient()} convenience method was removed; callers must
 * name the {@link SslProvider} explicitly.
 *
 * <p>Only {@link SslProvider#JDK} and {@link SslProvider#OPENSSL} are cached. For any other
 * provider (including {@code null}) the getters return {@code null} without building anything.
 *
 * <p>Thread-safety: the cached contexts are published through {@code volatile} fields and are
 * created/cleared while holding a private lock, so the getters and {@link #clear()} may be
 * called concurrently.
 *
 * @author tony04.liu
 */
public class SslContextFactory {
	private static final Logger logger = LoggerFactory.getLogger(SslContextFactory.class);

	/**
	 * Private monitor guarding creation and clearing of the cached contexts. A dedicated object
	 * is used so that no code outside this class can contend on (or deadlock with) the same lock.
	 */
	private static final Object LOCK = new Object();

	private static volatile SslContext clientJDK; // client context, JDK-based SSL implementation
	private static volatile SslContext serverJDK; // server context, JDK-based SSL implementation

	private static volatile SslContext clientOpenSSL; // client context, OpenSSL-based SSL implementation
	private static volatile SslContext serverOpenSSL; // server context, OpenSSL-based SSL implementation
	private static volatile SslContext serverHttpOpenSSL; // HTTPS server context, OpenSSL-based SSL implementation
	// ------------------------------------------------------------------------------------------------------------

	/**
	 * Drops every cached client and server context so that the next getter call rebuilds it.
	 * Contexts already handed out to callers are not closed or otherwise touched.
	 */
	public static void clear() {
		synchronized (LOCK) {
			clientJDK = null;
			serverJDK = null;

			clientOpenSSL = null;
			serverOpenSSL = null;
			// The HTTPS-specific OpenSSL server context must be reset together with the others.
			serverHttpOpenSSL = null;
		}
	}

	/**
	 * Returns the globally shared client {@link SslContext} for the given provider, creating it
	 * on first use.
	 *
	 * <p>NOTE(review): the context is built with {@link InsecureTrustManagerFactory}, which Netty
	 * documents as trusting every certificate without verification. Confirm this is acceptable
	 * for every caller of this factory.
	 *
	 * @param provider SSL implementation to use; supported values are {@link SslProvider#JDK}
	 *                 and {@link SslProvider#OPENSSL}
	 * @return the shared client context, or {@code null} when {@code provider} is not supported
	 * @throws Exception if Netty fails to build the context (the {@link SSLException} is the cause)
	 */
	public static SslContext getClientContext(SslProvider provider) throws Exception {
		if (!isSupported(provider)) {
			// Never build or cache anything for a provider that getClient0 cannot hand back.
			logger.warn("unsupported SslProvider:{} for client, no SslContext created", provider);
			return null;
		}

		SslContext context = getClient0(provider);
		if (context != null) {
			return context; // fast path: already built
		}

		synchronized (LOCK) {
			context = getClient0(provider); // re-check: another thread may have built it meanwhile
			if (context == null) {
				try {
					long start = System.currentTimeMillis(); // measure build time
					context = SslContextBuilder.forClient().sslProvider(provider)
							.trustManager(InsecureTrustManagerFactory.INSTANCE).build();
					if (SslProvider.JDK == provider) {
						clientJDK = context;
					} else {
						clientOpenSSL = context; // provider is OPENSSL here, see isSupported
					}
					logger.info("create SslContext(SslProvider:{}) for client consume：{}", provider,
							System.currentTimeMillis() - start);
				} catch (SSLException e) {
					throw new Exception(
							"get SslContext(SslProvider:" + provider + ") for client fail.", e);
				}
			}
			// Return the local reference so a concurrent clear() cannot turn the result into null.
			return context;
		}
	}

	// ------------------------------------------------------------------------------------------------------------
	/**
	 * Returns the globally shared server {@link SslContext} for the given provider, creating it
	 * on first use from the certificate supplied by {@code SelfSignedCertificateFactory}.
	 *
	 * <p>For {@link SslProvider#JDK} a single context is shared regardless of {@code isHttps}.
	 * For {@link SslProvider#OPENSSL} the HTTPS and non-HTTPS contexts are cached separately;
	 * the HTTPS one is restricted to a non-GCM cipher list when
	 * {@code SslProviderUtil.isHttpsUseGcm()} returns {@code false}.
	 *
	 * @param provider SSL implementation to use; supported values are {@link SslProvider#JDK}
	 *                 and {@link SslProvider#OPENSSL}
	 * @param isHttps  whether the context is meant for an HTTPS server (only affects OpenSSL)
	 * @return the shared server context, or {@code null} when {@code provider} is not supported
	 * @throws Exception if Netty fails to build the context (the {@link SSLException} is the
	 *                   cause), or whatever {@code SelfSignedCertificateFactory.getInstance()} throws
	 */
	public static SslContext getServerContext(SslProvider provider, boolean isHttps) throws Exception {
		if (!isSupported(provider)) {
			logger.warn("unsupported SslProvider:{} for server, no SslContext created", provider);
			return null;
		}

		SslContext context = getServer0(provider, isHttps);
		if (context != null) {
			return context; // fast path: already built
		}

		synchronized (LOCK) {
			context = getServer0(provider, isHttps); // re-check under the lock
			if (context == null) {
				try {
					long start = System.currentTimeMillis(); // measure build time
					SelfSignedCertificate ssc = SelfSignedCertificateFactory.getInstance();
					SslContextBuilder builder = SslContextBuilder
							.forServer(ssc.certificate(), ssc.privateKey()).sslProvider(provider);

					if (SslProvider.JDK == provider) {
						context = builder.build();
						serverJDK = context;
					} else if (isHttps) {
						// When GCM is usable the HTTPS context is built like any other one.
						if (!SslProviderUtil.isHttpsUseGcm()) {
							// HTTPS special case: the original author derived this list from
							// ReferenceCountedOpenSslContext's DEFAULT_CIPHERS with "AES128-GCM-SHA256"
							// and "ECDHE-RSA-AES128-GCM-SHA256" removed.
							// NOTE(review): "DES-CBC3-SHA" is a 3DES suite and is generally considered
							// weak; left in place to keep existing clients working - confirm.
							List<String> ciphers = new ArrayList<String>();
							Collections.addAll(ciphers, "ECDHE-RSA-AES128-SHA", "ECDHE-RSA-AES256-SHA",
									"AES128-SHA", "AES256-SHA", "DES-CBC3-SHA");
							builder.ciphers(ciphers);
						}
						context = builder.build();
						serverHttpOpenSSL = context;
					} else {
						context = builder.build();
						serverOpenSSL = context;
					}
					logger.info("create SslContext(SslProvider:{}) for server consume：{}", provider,
							System.currentTimeMillis() - start);
				} catch (SSLException e) {
					throw new Exception(
							"get SslContext(SslProvider:" + provider + ") for server fail.", e);
				}
			}
			// Return the local reference so a concurrent clear() cannot turn the result into null.
			return context;
		}
	}

	// ------------------------------------------------------------------------------------------------------------
	/** Tells whether this factory keeps a cache slot for the given provider. */
	private static boolean isSupported(SslProvider provider) {
		return SslProvider.JDK == provider || SslProvider.OPENSSL == provider;
	}

	/** Reads the cached client context for the provider; {@code null} if absent or unsupported. */
	private static SslContext getClient0(SslProvider provider) {
		if (SslProvider.JDK == provider) {
			return clientJDK;
		} else if (SslProvider.OPENSSL == provider) {
			return clientOpenSSL;
		} else {
			return null;
		}
	}

	// ------------------------------------------------------------------
	/**
	 * Reads the cached server context for the provider; {@code null} if absent or unsupported.
	 * {@code isHttps} only selects between the two OpenSSL slots.
	 */
	private static SslContext getServer0(SslProvider provider, boolean isHttps) {
		if (SslProvider.JDK == provider) {
			return serverJDK;
		} else if (SslProvider.OPENSSL == provider) {
			if (isHttps) {
				return serverHttpOpenSSL;
			} else {
				return serverOpenSSL;
			}
		} else {
			return null;
		}
	}

	// ------------------------------------------------------------------
}
